# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
    This script processes downloaded ScanNet scans into the format expected by the ScanNet dataset loader.
    Load Scannet scenes with vertices and ground truth labels for semantic and
    instance segmentations.
    Code is as-is and hasn't been tested with edge cases/other data.

    # Modified from
    # https://github.com/facebookresearch/votenet/blob/master/scannet/load_scannet_data.py

    Which has:
    # MIT License
    # Copyright (c) Meta Platforms, Inc. and affiliates.
    # Permission is hereby granted, free of charge, to any person obtaining a copy
    # of this software and associated documentation files (the "Software"), to deal
    # in the Software without restriction, including without limitation the rights
    # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
    # copies of the Software, and to permit persons to whom the Software is
    # furnished to do so, subject to the following conditions:

    # The above copyright notice and this permission notice shall be included in all
    # copies or substantial portions of the Software.

    # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
    # SOFTWARE.
"""
import argparse
import inspect
import json
import os
from typing import Dict, List

import numpy as np
import scannet_utils

# Absolute path of the directory containing this script. `__file__` is the
# idiomatic way to get it; the inspect.currentframe() dance is unnecessary.
currentdir = os.path.dirname(os.path.abspath(__file__))


def read_aggregation(filename):
    """Parse a ScanNet ``*.aggregation.json`` file.

    Args:
        filename (str): Path to the aggregation JSON file.

    Returns:
        tuple: ``(object_id_to_segs, label_to_segs)`` where
            - object_id_to_segs maps instance id -> list of segment ids,
            - label_to_segs maps semantic label -> list of segment ids
              (accumulated over all instances with that label).
    """
    assert os.path.isfile(filename)
    object_id_to_segs = {}
    label_to_segs = {}
    with open(filename) as f:
        data = json.load(f)
    for group in data["segGroups"]:
        # NOTE: objectId is used as-is (0-indexed in raw ScanNet files);
        # the upstream votenet code added 1 to make instance ids 1-indexed.
        object_id = group["objectId"]
        label = group["label"]
        segs = group["segments"]
        # Copy the list so that extending label_to_segs below can never
        # mutate the per-object entry (the original code aliased the same
        # list object in both dicts, so a repeated label corrupted
        # object_id_to_segs).
        object_id_to_segs[object_id] = list(segs)
        label_to_segs.setdefault(label, []).extend(segs)
    return object_id_to_segs, label_to_segs


def read_segmentation(filename):
    """Parse a ScanNet ``*.segs.json`` file.

    Args:
        filename (str): Path to the over-segmentation JSON file.

    Returns:
        tuple: ``(seg_to_verts, num_verts)`` where seg_to_verts maps a
            segment id to the list of vertex indices it covers, and
            num_verts is the total number of mesh vertices.
    """
    assert os.path.isfile(filename)
    with open(filename) as f:
        seg_indices = json.load(f)["segIndices"]
    seg_to_verts = {}
    # segIndices[i] is the segment id of vertex i; invert that mapping.
    for vert_idx, seg_id in enumerate(seg_indices):
        seg_to_verts.setdefault(seg_id, []).append(vert_idx)
    return seg_to_verts, len(seg_indices)


def extract_bbox(
    mesh_vertices: np.ndarray,
    object_id_to_segs: Dict[int, List[int]],
    object_id_to_label_id: Dict[int, int],
    instance_ids: np.ndarray,
):
    """Compute an axis-aligned bounding box for each annotated instance.

    Args:
        mesh_vertices: (N, >=3) array; columns 0:3 are xyz coordinates.
        object_id_to_segs: Instance id -> segment ids. Only the keys are
            used here, to define the set of instances.
        object_id_to_label_id: Instance id -> semantic label id.
        instance_ids: (N,) per-vertex instance id.

    Returns:
        np.ndarray: (num_instances, 8) array with rows
            [cx, cy, cz, dx, dy, dz, label_id, object_id].
            Instances with no vertices keep an all-zero row.
    """
    obj_ids = sorted(object_id_to_segs.keys())
    num_instances = len(obj_ids)
    # Stable object-id -> row mapping. For the conventional 1-indexed ids
    # 1..N this reproduces the original `obj_id - 1` layout, but it also
    # stays in bounds for 0-indexed ids (where `obj_id - 1` wrapped object
    # 0 into the LAST row and silently overwrote another instance's bbox).
    row_of = {obj_id: row for row, obj_id in enumerate(obj_ids)}
    instance_bboxes = np.zeros((num_instances, 8))
    for obj_id in object_id_to_segs:
        obj_pc = mesh_vertices[instance_ids == obj_id, 0:3]
        if len(obj_pc) == 0:
            # Unannotated / empty instance: leave its zero row in place.
            # (Label lookup happens after this check so a missing label
            # for an empty instance cannot raise.)
            continue
        label_id = object_id_to_label_id[obj_id]
        xyz_min = np.min(obj_pc, axis=0)
        xyz_max = np.max(obj_pc, axis=0)
        bbox = np.concatenate(
            [
                (xyz_min + xyz_max) / 2.0,  # box center
                xyz_max - xyz_min,  # box size
                np.array([label_id]),
                np.array([obj_id]),
            ],
        )
        instance_bboxes[row_of[obj_id], :] = bbox
    return instance_bboxes


def export(
    mesh_file,
    agg_file,
    seg_file,
    meta_file,
    label_map_file,
    output_file=None,
    test_mode=False,
):
    """Export original files to vert, ins_label, sem_label and bbox file.

    Args:
        mesh_file (str): Path of the mesh_file.
        agg_file (str): Path of the agg_file.
        seg_file (str): Path of the seg_file.
        meta_file (str): Path of the meta_file.
        label_map_file (str): Path of the label_map_file.
        output_file (str): Prefix for the output ``.npy`` files.
            Default: None (nothing is written to disk).
        test_mode (bool): Whether is generating test data without labels.
            Default: False.

    It returns a tuple, which contains the the following things:
        np.ndarray: Vertices of points data.
        np.ndarray: Indexes of label.
        np.ndarray: Indexes of instance.
        np.ndarray: Instance bboxes.
        dict: Map from object_id to label_id.
        np.ndarray: 4x4 axis alignment matrix.
    """

    label_map = scannet_utils.read_label_mapping(
        label_map_file, label_from="raw_category", label_to="id"  # nyu40id
    )
    mesh_vertices = scannet_utils.read_mesh_vertices_rgb(mesh_file)

    # Load scene axis alignment matrix. Test-set scans have no
    # axisAlignment entry, so default to identity.
    axis_align_matrix = np.eye(4)
    # `with` ensures the meta file handle is closed (the previous
    # `open(meta_file).readlines()` leaked it).
    with open(meta_file) as f:
        for line in f:
            if "axisAlignment" in line:
                # The line looks like "axisAlignment = <16 floats>".
                # Split on "=": the old `strip("axisAlignment = ")` treated
                # its argument as a *character set* stripped from both
                # ends, which is fragile rather than a prefix removal.
                axis_align_matrix = [
                    float(x) for x in line.split("=")[1].split()
                ]
                break
    axis_align_matrix = np.array(axis_align_matrix).reshape((4, 4))

    # Perform global alignment of mesh vertices (homogeneous transform of
    # the xyz columns; rgb columns pass through unchanged).
    pts = np.ones((mesh_vertices.shape[0], 4))
    pts[:, 0:3] = mesh_vertices[:, 0:3]
    pts = np.dot(pts, axis_align_matrix.transpose())  # Nx4
    aligned_mesh_vertices = np.concatenate([pts[:, 0:3], mesh_vertices[:, 3:]], axis=1)

    # Load semantic and instance labels.
    if not test_mode:
        object_id_to_segs, label_to_segs = read_aggregation(agg_file)
        seg_to_verts, num_verts = read_segmentation(seg_file)
        # Per-vertex semantic label id; 0 means unannotated.
        label_ids = np.zeros(shape=(num_verts), dtype=np.uint32)
        object_id_to_label_id = {}
        for label, segs in label_to_segs.items():
            label_id = label_map[label]
            for seg in segs:
                verts = seg_to_verts[seg]
                label_ids[verts] = label_id
        # Per-vertex instance id; 0 means unannotated.
        instance_ids = np.zeros(shape=(num_verts), dtype=np.uint32)
        for object_id, segs in object_id_to_segs.items():
            for seg in segs:
                verts = seg_to_verts[seg]
                instance_ids[verts] = object_id
                if object_id not in object_id_to_label_id:
                    # All verts of an instance share one semantic label, so
                    # the first vertex is representative.
                    object_id_to_label_id[object_id] = label_ids[verts][0]
        unaligned_bboxes = extract_bbox(
            mesh_vertices, object_id_to_segs, object_id_to_label_id, instance_ids
        )
        aligned_bboxes = extract_bbox(
            aligned_mesh_vertices,
            object_id_to_segs,
            object_id_to_label_id,
            instance_ids,
        )
    else:
        label_ids = None
        instance_ids = None
        unaligned_bboxes = None
        aligned_bboxes = None
        object_id_to_label_id = None

    if output_file is not None:
        np.save(output_file + "_vert.npy", mesh_vertices)
        if not test_mode:
            np.save(output_file + "_sem_label.npy", label_ids)
            np.save(output_file + "_ins_label.npy", instance_ids)
            np.save(output_file + "_unaligned_bbox.npy", unaligned_bboxes)
            np.save(output_file + "_aligned_bbox.npy", aligned_bboxes)
            np.save(output_file + "_axis_align_matrix.npy", axis_align_matrix)

    return (
        mesh_vertices,
        label_ids,
        instance_ids,
        unaligned_bboxes,
        aligned_bboxes,
        object_id_to_label_id,
        axis_align_matrix,
    )


def main():
    """Parse command-line arguments and export one ScanNet scan."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--scan_path",
        required=True,
        # Fixed: the example previously had an unclosed parenthesis.
        help="path to scannet scene (e.g., data/ScanNet/v2/scene0000_00)",
    )
    parser.add_argument("--output_file", required=True, help="output file")
    parser.add_argument(
        "--label_map_file", required=True, help="path to scannetv2-labels.combined.tsv"
    )
    opt = parser.parse_args()

    # Raw ScanNet scans name every file after the scene directory itself.
    scan_name = os.path.split(opt.scan_path)[-1]
    mesh_file = os.path.join(opt.scan_path, scan_name + "_vh_clean_2.ply")
    agg_file = os.path.join(opt.scan_path, scan_name + ".aggregation.json")
    seg_file = os.path.join(opt.scan_path, scan_name + "_vh_clean_2.0.010000.segs.json")
    meta_file = os.path.join(
        opt.scan_path, scan_name + ".txt"
    )  # includes axisAlignment info for the train set scans.
    export(
        mesh_file, agg_file, seg_file, meta_file, opt.label_map_file, opt.output_file
    )
